Skip to content

Add router replay for MoE models#2101

Merged
Phlip79 merged 24 commits intoNVIDIA:mainfrom
litianjian:feat/router_replay
Jan 27, 2026
Merged

Add router replay for MoE models#2101
Phlip79 merged 24 commits intoNVIDIA:mainfrom
litianjian:feat/router_replay

Conversation

@litianjian
Copy link
Copy Markdown
Contributor

@litianjian litianjian commented Nov 3, 2025

What does this PR do ?

This PR introduces a "Router Replay" feature for Mixture-of-Experts (MoE) layers. This functionality provides a deterministic routing mechanism, which is essential for debugging, controlled experimentation, and reproducing model behavior.

Inspired by recent approaches in stabilizing MoE models Router Replay(R2) and Rollout Router Replay(R3), RouterReplay implementation allows developers to easily save and set the router's replay information, providing precise control over the expert selection process to mitigate routing inconsistencie

Implementation Details:

  1. Configuration Flag:
  • A new boolean flag, enable_routing_replay , has been added to TransformerConfig . This allows users to enable or disable the feature globally.
  1. RouterReplay Class:
  • A new class, RouterReplay , is introduced in moe_utils.py to manage the state and data for the replay functionality.
  • It operates in three modes, defined by the RoutingMode enum:
    • None : The default state, where standard routing occurs.
    • RECORD : The router records the expert indices ( topk_ids ) for each token.
    • REPLAY : The router bypasses the standard Top-K logic and instead uses the previously recorded expert indices to route tokens.
  1. Integration with TopKRouter :
  • In router.py , the TopKRouter now initializes a RouterReplay instance if config.enable_routing_replay is True .
  • This instance is then passed to the topk_routing_with_score_function during the routing process.
  1. Core Logic in moe_utils.py :
  • The topk_routing_with_score_function has been updated to handle the router_replay object.
  • When the mode is RECORD , it captures the topk_ids after they are computed.
  • When the mode is REPLAY , it retrieves the stored topk_ids from the RouterReplay object and uses them to construct the routing gates and probabilities, effectively bypassing the dynamic routing calculation.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share discuss a design-doc with the team.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
Loading

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge-conflict are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@litianjian litianjian requested review from a team as code owners November 3, 2025 10:15
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Nov 3, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yanring yanring added the Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. label Nov 7, 2025
Comment thread megatron/core/transformer/moe/moe_utils.py Outdated
Comment thread megatron/core/transformer/moe/moe_utils.py Outdated
Comment thread megatron/core/transformer/moe/moe_utils.py Outdated
Comment thread megatron/core/transformer/moe/moe_utils.py Outdated
@ISEEKYAN ISEEKYAN requested review from a team as code owners November 20, 2025 06:03
@ISEEKYAN
Copy link
Copy Markdown
Contributor

@litianjian it is better if we add a doc to give a minimal demo for R2 as guidance

@litianjian litianjian force-pushed the feat/router_replay branch 2 times, most recently from 8e9ef4b to f6eb81d Compare November 24, 2025 06:25
ISEEKYAN pushed a commit to verl-project/verl that referenced this pull request Dec 4, 2025
### What does this PR do?

This PR introduces a draft **Router Replay** support into Verl.
Inspired by the recent research in **MoE Reinforcement
Learning**([2510.11370](https://arxiv.org/abs/2510.11370),
[2507.18071](https://arxiv.org/abs/2507.18071)), this implementation
supports **Router Replay (R2)** and **Rollout Router Replay (R3)**.
R2 allows recording routing token selection during` log probability
computation` and replaying expert selection during policy update. R3
enables recording during `model inference` and replaying during RL
post-training.

The initial version supports **Router Replay** with `Megatron` backend,
including comprehensive support for distributed training strategies
(**DP, TP, EP, ETP, PP, and Re-compute**).


The current implementation uses a patch-based approach. Once the
upstream PR
[NVIDIA/Megatron-LM#2101](NVIDIA/Megatron-LM#2101)
is merged or provides corresponding interfaces, the patch can be removed
and replaced with official API integration.

## Usage Tutorial

### Basic Configuration
To enable Router Replay functionality, add the following configuration
to your trainer config:
#### Method 1: Trainer Configuration
Add the following configuration to your trainer config:

```yaml
router_replay:
  enabled: true
  mode: "R2"  # Options: "R2", "R3"
```

#### Method 2: Launch Script Configuration
Add the following parameter to your launch script:

```bash
# In your launch script
actor_rollout_ref.actor.router_replay.mode="R2"
```

### R2 Mode Usage
1. **Enable R2 mode** in configuration
2. **Record phase**: During log probability computation, routing
selections are automatically recorded
3. **Replay phase**: During policy update, recorded expert selections
are replayed

### R3 Mode Usage
1. **Enable R3 mode** in configuration
2. **Record phase**: During model inference, routing decisions are
captured
3. **Replay phase**: During RL post-training, recorded routing data is
used
4. 
## In Progress
R2
- [ ]  FSDP backend

R3
- [x] vLLM Rollout
- [ ] Sglang Rollout

---------

Co-authored-by: litianjian <litianjian@bytedance.com>
Co-authored-by: zhangbiao.168 <zhangbiao.168@bytedance.com>
@ISEEKYAN ISEEKYAN enabled auto-merge December 9, 2025 06:23
@sidsingh-nvidia
Copy link
Copy Markdown
Contributor

/ok to test 3260df1

@sidsingh-nvidia
Copy link
Copy Markdown
Contributor

@jon-barker Does this look okay now?

@Phlip79
Copy link
Copy Markdown
Member

Phlip79 commented Jan 26, 2026

/ok to test 02a65df

@Phlip79
Copy link
Copy Markdown
Member

Phlip79 commented Jan 26, 2026

/ok to test dc2d3b2

Copy link
Copy Markdown
Contributor

@jon-barker jon-barker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm - thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

community-request complexity: medium Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. Final Review PR is in the "final review" stage module: moe Run functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.